import os
import yaml
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.cuda import amp
import argparse
from utils.data_utils import get_sequential_data_loaders
from models.FLAMES_model import FLAMESModel

def _parse_args():
    # Define and parse the command-line options for this training script.
    parser = argparse.ArgumentParser(description='Train FLAMESModel on Sequential CIFAR10/100')
    parser.add_argument('-device', default='cuda:0', help='Device to train on (e.g., "cuda:0" or "cpu")')
    parser.add_argument('-b', default=128, type=int, help='Batch size for training and testing')
    parser.add_argument('-epochs', default=64, type=int, metavar='N', help='Number of total epochs to run')
    parser.add_argument('-j', default=4, type=int, metavar='N', help='Number of data loading workers (default: 4)')
    parser.add_argument('-data-dir', type=str, required=True, help='Root directory of CIFAR10/100 dataset')
    parser.add_argument('-out-dir', type=str, default='./logs', help='Directory for saving logs and checkpoints')
    parser.add_argument('-resume', type=str, help='Resume training from the checkpoint path', default=None)
    parser.add_argument('-amp', action='store_true', help='Enable automatic mixed precision training')
    parser.add_argument('-opt', type=str, default='adam', help='Optimizer to use: "sgd" or "adam"')
    parser.add_argument('-momentum', default=0.9, type=float, help='Momentum for SGD')
    parser.add_argument('-lr', default=0.001, type=float, help='Learning rate')
    parser.add_argument('-channels', default=128, type=int, help='Channels for CSNN layers')
    parser.add_argument('-neu', type=str, help='Neuron type (specific to your model)', default='lif')
    parser.add_argument('-class-num', type=int, default=10, help='Number of classes (10 for CIFAR-10, 100 for CIFAR-100)')
    parser.add_argument('-mixup', action='store_true', help='Use MixUp data augmentation')
    parser.add_argument('-cutmix', action='store_true', help='Use CutMix data augmentation')
    parser.add_argument('-num-steps', type=int, default=10, help='Number of steps (frames) in the sequence for each sample')
    return parser.parse_args()


def _build_optimizer(model, args):
    # Create the optimizer named by ``args.opt`` (case-insensitive: "adam" or "sgd").
    opt_name = args.opt.lower()
    if opt_name == 'adam':
        return torch.optim.Adam(model.parameters(), lr=args.lr)
    if opt_name == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    raise NotImplementedError(f"Optimizer {args.opt} not supported.")


def _make_checkpoint(model, optimizer, scheduler, scaler, epoch, best_accuracy):
    # Bundle all state needed to resume training (including the AMP scaler) into one dict.
    return {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
        # Persist the GradScaler so a resumed AMP run keeps its calibrated loss scale.
        'scaler': scaler.state_dict() if scaler is not None else None,
        'epoch': epoch,
        'best_accuracy': best_accuracy,
    }


def _resume(path, model, optimizer, scheduler, scaler):
    # Restore training state from ``path``; returns (start_epoch, best_accuracy).
    if not os.path.isfile(path):
        # Best-effort: a missing checkpoint falls back to training from scratch.
        print(f"Checkpoint not found at: {path}")
        return 0, -1.0

    # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if scheduler is not None and checkpoint.get('scheduler') is not None:
        scheduler.load_state_dict(checkpoint['scheduler'])
    # Older checkpoints have no 'scaler' entry; the scaler then keeps its defaults.
    if scaler is not None and checkpoint.get('scaler') is not None:
        scaler.load_state_dict(checkpoint['scaler'])

    start_epoch = checkpoint['epoch'] + 1
    best_accuracy = checkpoint.get('best_accuracy', -1.0)
    print(f"Resumed training from checkpoint: {path} (Epoch {start_epoch})")
    return start_epoch, best_accuracy


def _epoch_metrics(loss_sum, correct, seen, split):
    # Convert accumulated sums into (mean per-sample loss, accuracy in percent).
    if seen == 0:
        raise ValueError(f"{split} loader yielded no samples; cannot compute epoch metrics.")
    # Divide by samples actually seen (not len(dataset)) so loss and accuracy share a denominator.
    return loss_sum / seen, 100.0 * correct / seen


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp, epoch, num_epochs):
    # Run one optimization pass over ``loader``; returns (mean loss, accuracy %).
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    num_batches = len(loader)

    for step, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        with amp.autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # Backpropagation with mixed precision if a GradScaler was created.
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        batch_size = labels.size(0)
        # criterion is a mean-reduced CrossEntropyLoss, so re-weight by batch size.
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        seen += batch_size

        if step % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step+1}/{num_batches}], Loss: {loss.item():.4f}")

    return _epoch_metrics(loss_sum, correct, seen, 'Train')


def _evaluate(model, loader, criterion, device):
    # Score the model on ``loader`` without gradient tracking; returns (mean loss, accuracy %).
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            seen += batch_size

    return _epoch_metrics(loss_sum, correct, seen, 'Test')


def main():
    """Train FLAMESModel on sequential CIFAR data from command-line options.

    Per epoch this trains, evaluates on the test loader, steps a cosine-annealing
    LR schedule, and logs loss/accuracy to TensorBoard under ``<out-dir>/tensorboard``.
    It writes ``latest_checkpoint.pth`` every epoch and ``best_checkpoint.pth``
    whenever test accuracy improves. Checkpoints include model, optimizer,
    scheduler and (when ``-amp`` is set) GradScaler state, so ``-resume``
    continues the run faithfully.

    Raises:
        NotImplementedError: if ``-opt`` is neither "adam" nor "sgd".
        ValueError: if the train or test loader yields no samples.
    """
    args = _parse_args()

    # Create the output directory if it does not exist.
    os.makedirs(args.out_dir, exist_ok=True)
    writer = SummaryWriter(os.path.join(args.out_dir, "tensorboard"))

    # Close the writer even on errors or Ctrl-C so buffered events are flushed.
    try:
        # Fixed random seed for reproducibility.
        torch.manual_seed(2020)
        if args.device.startswith("cuda") and torch.cuda.is_available():
            torch.cuda.manual_seed(2020)

        # The full args namespace is forwarded; presumably the loader reads options
        # such as -b, -j, -data-dir, -mixup and -cutmix from it. TODO confirm.
        train_loader, test_loader = get_sequential_data_loaders(args, num_steps=args.num_steps)

        # NOTE(review): -channels and -neu are not passed to the model here.
        model = FLAMESModel(
            input_channels=3,
            num_classes=args.class_num,
            decay_rate=0.1,
            tau_d_list=[2.0, 5.0, 10.0],
            tau_s=5.0
        )
        device = torch.device(args.device)
        model.to(device)

        criterion = torch.nn.CrossEntropyLoss()
        optimizer = _build_optimizer(model, args)
        # Stepped once per epoch, so T_max=epochs spans the whole run with one cosine cycle.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
        scaler = amp.GradScaler() if args.amp else None

        start_epoch = 0
        best_accuracy = -1.0
        if args.resume:
            start_epoch, best_accuracy = _resume(args.resume, model, optimizer, scheduler, scaler)

        for epoch in range(start_epoch, args.epochs):
            train_loss, train_accuracy = _train_one_epoch(
                model, train_loader, criterion, optimizer, scaler,
                device, args.amp, epoch, args.epochs,
            )
            writer.add_scalar('Train/Loss', train_loss, epoch)
            writer.add_scalar('Train/Accuracy', train_accuracy, epoch)

            test_loss, test_accuracy = _evaluate(model, test_loader, criterion, device)
            writer.add_scalar('Test/Loss', test_loss, epoch)
            writer.add_scalar('Test/Accuracy', test_accuracy, epoch)

            scheduler.step()

            is_best = test_accuracy > best_accuracy
            if is_best:
                best_accuracy = test_accuracy
            # Snapshot after scheduler.step() so a resumed run continues at the next epoch's LR.
            checkpoint = _make_checkpoint(model, optimizer, scheduler, scaler, epoch, best_accuracy)
            if is_best:
                torch.save(checkpoint, os.path.join(args.out_dir, 'best_checkpoint.pth'))
                print(f"Best model saved at epoch {epoch + 1} with accuracy {test_accuracy:.2f}%")
            torch.save(checkpoint, os.path.join(args.out_dir, 'latest_checkpoint.pth'))

            print(f"Epoch [{epoch+1}/{args.epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, "
                  f"Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.2f}%")
    finally:
        writer.close()

if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
